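# Patch for train.py that replaces the ImageNet train/eval datasets with a
# synthetic in-memory dataset (DummyImageDataset) and removes the positional
# `data_dir` argument, so no dataset path is required.
# Use only for benchmarking purposes.
# NOTE(review): the train split is sized at 1231167 images, whereas ImageNet-1k
# train has 1,281,167 -- confirm whether the difference is intentional.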

diff --git a/train.py b/train.py
index 6e3b058..8ddbcdd 100755
--- a/train.py
+++ b/train.py
@@ -61,6 +61,55 @@ except ImportError:
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('train')
 
+
+class DummyImageDataset(torch.utils.data.Dataset):
+    """Synthetic stand-in for a real image dataset (benchmarking only).
+
+    One random RGB image of ``_IMAGE_HEIGHT`` x ``_IMAGE_WIDTH`` pixels is
+    generated in ``__init__`` and returned for every index; the label for
+    index ``idx`` is ``idx % num_classes``.
+
+    ``transform`` and ``target_transform`` start out as ``None`` and can be
+    assigned after construction. When either is ``None`` the corresponding
+    value is returned unmodified.
+    """
+    _IMAGE_HEIGHT = 3072
+    _IMAGE_WIDTH = 2304
+
+    def __init__(self, num_images, num_classes):
+        """Build the shared image and store the dataset dimensions.
+
+        Args:
+            num_images: Length reported by ``__len__``.
+            num_classes: Modulus used to derive a label from an index.
+        """
+        import numpy as np
+        from PIL import Image
+        # Uniform noise in [0, 255), truncated to uint8 below.
+        imarray = np.random.rand(self._IMAGE_HEIGHT, self._IMAGE_WIDTH, 3) * 255
+        self.img = Image.fromarray(imarray.astype('uint8')).convert('RGB')
+        self.num_images = num_images
+        self.num_classes = num_classes
+        self.transform = None
+        self.target_transform = None
+
+    def __len__(self):
+        """Return the configured number of images."""
+        return self.num_images
+
+    def __getitem__(self, idx):
+        """Return ``(image, label)`` for ``idx``, applying any set transforms."""
+        # Start from the raw image so ``img`` is bound even when no transform
+        # has been assigned (previously this path raised UnboundLocalError).
+        img = self.img
+        if self.transform is not None:
+            img = self.transform(img)
+        target = idx % self.num_classes
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return img, target
+
+
 # The first arg parser parses out only the --config argument, this argument is used to
 # load a yaml file containing key-values that override the defaults for the main parser below
 config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
@@ -71,8 +99,6 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
 
 # Dataset parameters
-parser.add_argument('data_dir', metavar='DIR',
-                    help='path to dataset')
 parser.add_argument('--dataset', '-d', metavar='NAME', default='',
                     help='dataset type (default: ImageFolder/ImageTar if empty)')
 parser.add_argument('--train-split', metavar='NAME', default='train',
@@ -486,17 +512,8 @@ def main():
         _logger.info('Scheduled epochs: {}'.format(num_epochs))
 
     # create the train and eval datasets
-    dataset_train = create_dataset(
-        args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
-        class_map=args.class_map,
-        download=args.dataset_download,
-        batch_size=args.batch_size,
-        repeats=args.epoch_repeats)
-    dataset_eval = create_dataset(
-        args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
-        class_map=args.class_map,
-        download=args.dataset_download,
-        batch_size=args.batch_size)
+    dataset_train = DummyImageDataset(num_images=1231167, num_classes=1000)
+    dataset_eval = DummyImageDataset(num_images=50000, num_classes=1000)
 
     # setup mixup / cutmix
     collate_fn = None
